import os
import json
import time
from datetime import datetime, timedelta
import string
import threading

import pandas as pd
import pymysql
import redis
from tqsdk import TqApi, TqAuth
from tqsdk.ta import ATR
import matplotlib.pyplot as plt
import mplfinance as mpf
import matplotlib.animation as animation

# Presumably selects the SimHei font so the Chinese labels in the figure render — confirm.
plt.rcParams["font.sans-serif"] = ["SimHei"]
# Draw the minus sign as a plain hyphen instead of the Unicode minus glyph.
plt.rcParams['axes.unicode_minus'] = False
# region Global variable definitions
DATA = {}  # Holds every k-line series obtained so far, keyed by "<symbol>_<period>" strings
STOP_MARK = False  # Stop flag: set to True to make the TqSdk update loop shut down
SYMBOL = ''  # DATA key of the contract that is currently displayed
CFG = {}
MULTIPLIER = ''  # Contract multiplier of the current SYMBOL


# endregion


class Draw:
    """Animated mplfinance window showing the strategy view selected in Redis.

    The Redis key named by ``cfg['redis_display_stg']`` holds the name of the
    active strategy key; the value stored under that strategy key is a JSON
    object with the indicator parameters.  K-line data is read from the
    module-level ``DATA`` dict under the module-level ``SYMBOL`` key.
    """

    def __init__(self, CFG):
        """Create the figure, its text labels and the Redis client.

        Args:
            CFG: Configuration dict.  This class reads ``redis_host``,
                ``redis_passwd``, ``redis_port``, ``redis_db`` and
                ``redis_display_stg`` from it.

        Raises:
            Exception: Anything raised while building the Redis client is
                printed and re-raised.
        """
        self.cfg = CFG
        self.df = pd.DataFrame()
        # Timestamp of the last refresh; update() uses it for throttling.
        self.time = time.time()
        self.fig = mpf.figure(style='binance', figsize=(13, 6))
        # Main panel on the upper half of the figure, secondary panel on the
        # bottom third.
        self.ax1 = self.fig.add_subplot(2, 1, 1)
        self.ax2 = self.fig.add_subplot(3, 1, 3)
        self.ani = None  # FuncAnimation handle, created in run()
        self.stg = None  # name of the active strategy key, read in update()
        self.normal_label_font = {'fontname': 'pingfang HK',
                                  'size': '12',
                                  'color': 'black',
                                  'weight': 'normal',
                                  'va': 'bottom',
                                  'ha': 'right'}
        self.text_stg = self.fig.text(0.1, 0.9, '策略：', **self.normal_label_font)
        self.text_price = self.fig.text(0.25, 0.9, '最新价：', **self.normal_label_font)
        self.text_vol = self.fig.text(0.4, 0.9, '成交量：', **self.normal_label_font)
        self.text_tech = self.fig.text(0.7, 0, '', **self.normal_label_font)
        self.text_time = self.fig.text(0.9, 0, '', **self.normal_label_font)

        # NOTE(review): redis.Redis() is believed not to open a connection
        # until the first command, so this handler may never fire — confirm.
        try:
            self.r = redis.Redis(host=self.cfg['redis_host'],
                                 password=self.cfg['redis_passwd'],
                                 port=self.cfg['redis_port'],
                                 db=self.cfg['redis_db'])
        except Exception as e:
            print(f'链接redis出现问题： {e}')
            raise

    def run(self):
        """Start the animation and show the window.

        ``mpf.show()`` presumably blocks until the window is closed — the
        caller relies on that to know when to shut down.
        """
        # The animation is stored on self so it stays referenced while shown.
        self.ani = animation.FuncAnimation(self.fig, self.update, interval=25)
        mpf.show()

    def draw_atr(self, df, period):
        """Add an ``atr_<period>`` column to ``df`` in place and return ``df``.

        Args:
            df: K-line frame; passed whole to ``ATR`` exactly as
                ``view_turtle`` does.
            period: ATR look-back length.
        """
        # The full frame is required: ATR is computed from high/low/close,
        # not from the close column alone.
        atr = ATR(df, period)
        df[f'atr_{period}'] = atr.atr
        return df

    def _load_view(self):
        """Return the decoded JSON parameters of the active strategy, or None."""
        raw = self.r.get(self.stg)
        if raw is None:
            # The strategy key can be missing or expired; skip this frame
            # instead of letting json.loads(None) raise inside the animation.
            print(f'未找到策略配置：{self.stg}')
            return None
        return json.loads(raw)

    @staticmethod
    def _tail_window(df, rows=50):
        """Return the last ``rows`` rows as an independent, time-indexed frame.

        The ``datetime`` column is converted with an origin of
        ``1970-01-01 08:00:00``, i.e. shifted by eight hours relative to the
        Unix epoch, and becomes the ``date`` index.
        """
        # .copy() detaches the slice so adding a column does not write into a
        # view of the caller's frame (avoids SettingWithCopyWarning).
        window = df.iloc[-rows:, :].copy()
        window['date'] = pd.to_datetime(window["datetime"], origin='1970-01-01 08:00:00')
        window.set_index('date', inplace=True)
        return window

    def view_sma(self):
        """Redraw the figure as candles plus a short and a long moving average.

        Reads ``ma_short`` and ``ma_long`` from the strategy's Redis JSON.
        Does nothing when there is no data for ``SYMBOL`` or no strategy
        configuration.
        """
        df = DATA.get(SYMBOL, pd.DataFrame()).copy(deep=True)
        if df.empty:
            return
        view = self._load_view()
        if view is None:
            return

        ma_long = view['ma_long']
        ma_short = view['ma_short']
        short_col = f"ma_{ma_short}"
        long_col = f"ma_{ma_long}"
        # Averages are computed on the full history before trimming so the
        # displayed window has values from its first row.
        df[short_col] = df['close'].rolling(ma_short).mean()
        df[long_col] = df['close'].rolling(ma_long).mean()

        df = self._tail_window(df)
        add_plot = [
            mpf.make_addplot(df[short_col], type='line', color='red', ax=self.ax1),
            mpf.make_addplot(df[long_col], type='line', color='green', ax=self.ax1),
        ]
        self.ax1.clear()
        self.ax2.clear()
        self.ax1.set_ylabel('SMA')
        mpf.plot(df, ax=self.ax1, volume=self.ax2, addplot=add_plot, type='candle',
                 style='binance',
                 datetime_format='%H:%M',
                 scale_width_adjustment=dict(volume=0.5, candle=1.15, lines=0.3))
        self.text_stg.set_text('策略：sma')
        self.text_price.set_text(f'最新价：{df["close"].iat[-1]}')
        # NOTE: the volume label is deliberately not refreshed in this view
        # (the original had that line disabled).
        self.text_tech.set_text(
            f"sma_short:{round(df[short_col].iat[-1], 3)}, sma_long:{round(df[long_col].iat[-1], 3)}")
        self.fig.canvas.manager.set_window_title(f'QM {SYMBOL}')

    def view_turtle(self):
        """Redraw the figure as candles with Donchian channels and an ATR panel.

        Reads ``long``, ``short`` and ``atr`` from the strategy's Redis JSON.
        Does nothing when ``SYMBOL`` has no contract part, when there is no
        data for ``SYMBOL``, or when the strategy configuration is missing.
        """
        if not SYMBOL.split('_')[0]:
            print('turtle redis配置中，没有填写标的！填写后才可以显示！')
            # Nothing can be drawn without a contract, so stop here.
            return
        df = DATA.get(SYMBOL, pd.DataFrame()).copy(deep=True)
        if df.empty:
            return
        view = self._load_view()
        if view is None:
            return

        long = view['long']
        short = view['short']
        us = f"donchain_up_{short}"
        ul = f"donchain_up_{long}"
        ls = f"donchain_down_{short}"
        ll = f"donchain_down_{long}"
        atr_col = f"atr_{view['atr']}"

        df[us] = df['high'].rolling(short).max()
        df[ul] = df['high'].rolling(long).max()
        df[ls] = df['low'].rolling(short).min()
        df[ll] = df['low'].rolling(long).min()
        df[atr_col] = ATR(df, view['atr']).atr

        df = self._tail_window(df)
        add_plot = [
            mpf.make_addplot(df[[us, ul, ls, ll]], ax=self.ax1),
            mpf.make_addplot(df[atr_col], ax=self.ax2),
        ]
        self.ax1.clear()
        self.ax2.clear()
        self.ax1.set_ylabel('turtle')
        self.ax2.set_ylabel('atr')
        mpf.plot(df, ax=self.ax1, addplot=add_plot, type='candle', style='binance',
                 datetime_format='%Y-%m-%d')
        self.text_stg.set_text('策略：Turtle')
        self.text_price.set_text(f'最新价：{df["close"].iat[-1]}')
        self.text_vol.set_text(f'成交量：{df["volume"].iat[-1]}')
        s = f"""UL:{df[ul].iat[-1]}, US:{df[us].iat[-1]}, LL:{df[ll].iat[-1]}, LS:{df[ls].iat[-1]}, ATR:{df[atr_col].iat[-1]}, HIGH:{df['high'].iat[-1]}, LOW:{df['low'].iat[-1]}, MULT:{MULTIPLIER}"""
        self.text_tech.set_text(s)
        self.fig.canvas.manager.set_window_title(f'QM {SYMBOL}')

    def update(self, itv):
        """Animation callback: refresh the active view at most every 0.5 s.

        Args:
            itv: Frame value supplied by ``FuncAnimation``; unused.
        """
        now = time.time()
        if now - self.time > 0.5:  # refresh once every half second
            # Remember when we refreshed; without this the throttle stays
            # open forever once the first half second has passed.
            self.time = now

            # One GET instead of EXISTS + GET: a single round trip, and no
            # failure if the key disappears between the two commands.
            raw = self.r.get(self.cfg['redis_display_stg'])
            if raw is not None:
                self.stg = raw.decode('utf-8')
                if 'sma' in self.stg:
                    self.view_sma()
                elif 'turtle' in self.stg:
                    self.view_turtle()
                else:
                    print(f'未知的view！{self.stg}')
            self.text_time.set_text(f'{time.ctime()}')
        # Kept from the original: paces the callback regardless of the timer.
        time.sleep(0.1)


class QuantMasterKernel:
    """Data kernel: owns the MySQL, Redis and TqSdk connections.

    It maps the lowercase symbols used in MySQL/Redis to TqSdk symbols,
    subscribes k-line series into the module-level ``DATA`` dict and, in
    ``update``, keeps the module-level ``SYMBOL`` / ``MULTIPLIER`` in sync
    with the display configuration stored in Redis.
    """

    def __init__(self, CFG):
        """Open all connections and load the ticker reference table.

        Args:
            CFG: Configuration dict with the ``mysql_*``, ``redis_*`` and
                ``tq_*`` entries read by this class.

        Raises:
            Exception: Connection errors are printed and re-raised.
        """
        self.cfg = CFG
        self.count = 0  # timestamp of the last half-second tick in update()
        self.pos_symbol = []
        self.redis_symbol = []
        self.general_ticker_info = None
        self.api = None
        self.display_symbol = {}
        self.stg = None  # raw strategy key name, set by get_display_symbol()

        try:
            self.conn = pymysql.connect(host=self.cfg['mysql_host'],
                                        port=self.cfg['mysql_port'],
                                        user=self.cfg['mysql_user'],
                                        password=self.cfg['mysql_passwd'],
                                        db=self.cfg['mysql_database'])
        except Exception as e:
            print(f'连接数据库出现问题：{e}')
            raise
        else:
            self.cur = self.conn.cursor()

        try:
            self.r = redis.Redis(host=self.cfg['redis_host'],
                                 password=self.cfg['redis_passwd'],
                                 port=self.cfg['redis_port'],
                                 db=self.cfg['redis_db'])
        except Exception as e:
            print(f'链接redis出现问题： {e}')
            raise

        self.general_ticker_info = self.get_general_ticker_info()
        # Build the frame that maps my all-lowercase symbols to TqSdk symbols.
        self.trans_symbol_df = self.build_symbol_trans_df()
        self.api = self.connect_tqsdk()

    def get_mysql_symbol(self):
        """Load the ``symbol`` column of the position table.

        Returns:
            The rows returned by the cursor; also stored in ``self.pos_symbol``.
        """
        sql = f"SELECT symbol FROM {self.cfg['mysql_table']}"
        self.cur.execute(sql)
        self.pos_symbol = self.cur.fetchall()
        # NOTE(review): committing after a SELECT presumably ends the read
        # transaction so later queries see fresh rows — confirm.
        self.conn.commit()
        return self.pos_symbol

    def get_general_ticker_info(self):
        """Return the general ticker info table as a DataFrame.

        The column names are assigned positionally, so the table's column
        order must match the list below.
        """
        columns = [
            'contract_name', 'index_name', 'minmum_move', 'contract_multiplier', 'margin', 'flat_rule',
            'contract_change_rule', 'ch_name', 'close_time', 'main_contract', 'type', 'sub_type'
        ]
        sql = f"select * from {self.cfg['mysql_general_ticker_info_table']}"
        self.cur.execute(sql)
        self.conn.commit()
        return pd.DataFrame(self.cur.fetchall(), columns=columns)

    def get_redis_symbol(self):
        """Collect the subscription entries from the temp and stable Redis lists.

        Returns:
            A list of decoded JSON objects such as
            ``{"symbol": "", "period": ""}``; also stored in
            ``self.redis_symbol``.
        """
        self.redis_symbol = []
        for key in (self.cfg['redis_temp_key'], self.cfg['redis_stable_key']):
            if self.r.exists(key):
                self.redis_symbol.extend(json.loads(raw) for raw in self.r.lrange(key, 0, -1))
        return self.redis_symbol

    def connect_tqsdk(self):
        """Create the TqApi session from the configured credentials."""
        return TqApi(auth=TqAuth(self.cfg["tq_username"], self.cfg["tq_passwd"]))

    def subscribe_kline(self):
        """Subscribe k-lines for MySQL positions (daily) and Redis entries.

        Series are stored in ``DATA`` under ``"<tq_symbol>_<period>"``;
        entries that are already subscribed, empty, or unknown to the
        ticker table are skipped.
        """
        for row in self.pos_symbol:
            if not row or not row[0]:
                continue
            tq_symbol = self.symbol2tq(row[0])
            if not tq_symbol:
                # symbol2tq already reported the unknown contract; the
                # original would have crashed concatenating None here.
                continue
            key = f'{tq_symbol}_86400'
            if key not in DATA:
                DATA[key] = self.api.get_kline_serial(tq_symbol, 86400)

        for item in self.redis_symbol:
            symbol = item['symbol']
            if not symbol:  # Redis may hand back an empty value
                continue
            tq_symbol = self.symbol2tq(symbol)
            if not tq_symbol:  # initialisation creates an empty entry; drop it
                continue
            period = item['period']
            # Default to 200 bars, which is also the default when no count is passed.
            count = item.get('count', 200)
            key = f'{tq_symbol}_{period}'
            if key not in DATA:
                DATA[key] = self.api.get_kline_serial(tq_symbol, period, count)

    def build_symbol_trans_df(self):
        """Add a lowercase ``symbol`` column and return a deep copy of the table.

        ``symbol`` is the last dot-separated part of ``contract_name``,
        lower-cased.  The column is added to ``self.general_ticker_info``
        itself as well.
        """
        self.general_ticker_info['symbol'] = \
            self.general_ticker_info['contract_name'].apply(lambda x: x.split('.')[-1].lower())

        return self.general_ticker_info.copy(deep=True)

    # All symbols in MySQL are lowercase and carry no exchange code.  CZCE
    # codes omit the decade digit: as of 2022-08-12 such a contract is
    # "xx208" where the other exchanges use "xx2208".
    def symbol2tq(self, symbol):
        """Translate a lowercase symbol into ``"<EXCHANGE>.<code>"``.

        Args:
            symbol: Contract code such as ``"rb2210"``.

        Returns:
            The translated symbol, or ``None`` (after printing a message)
            when the product is not in the ticker table.
        """
        product = symbol.strip(string.digits)
        table = self.trans_symbol_df
        matches = table.loc[table['symbol'] == product, 'contract_name']
        if matches.empty:
            print(f'合约不存在于general tick info 中，请检查{symbol}')
            return None

        parts = matches.iloc[0].split('.')
        exchange = parts[1]
        # Remove the literal "m@" prefix.  str.strip('m@') would strip any
        # leading/trailing 'm' or '@' characters, not just the prefix.
        if exchange.startswith('m@'):
            exchange = exchange[2:]
        # Keep the case the exchange uses for its product code.
        code = symbol if parts[-1].islower() else symbol.upper()
        return exchange + '.' + code

    def get_contract_multiplier(self, symbol):
        """Return the contract multiplier for ``symbol``.

        Returns ``None`` (after printing a message) when the product is not
        in the ticker table.
        """
        product = symbol.strip(string.digits)
        table = self.trans_symbol_df
        matches = table.loc[table['symbol'] == product, 'contract_multiplier']
        if matches.empty:
            print(f'合约不存在于general tick info 中，请检查{symbol}')
            return None
        return matches.iat[0]

    def get_display_symbol(self):
        """Refresh ``self.display_symbol`` from Redis on a best-effort basis.

        The key ``cfg['redis_display_stg']`` names the strategy key whose
        JSON value is the display configuration.  On any failure the
        previous ``self.display_symbol`` is kept and a message is printed.
        """
        try:
            self.stg = self.r.get(self.cfg['redis_display_stg'])
            if self.stg is None:
                print(f"redis中没有找到key：{self.cfg['redis_display_stg']}")
                return
            dis = self.r.get(self.stg)
            if dis is None:
                print(f'redis中没有找到key：{self.stg}')
                return
            parsed = json.loads(dis)
            if not isinstance(parsed, dict):
                print(f'display配置不是json对象：{parsed}')
                return
            self.display_symbol = parsed
        except Exception as e:
            # Deliberately broad: a bad refresh must not kill the update loop.
            print(e)

    def create_redis_temp_symbol(self):
        """Create the temp subscription list with a placeholder if it is absent.

        The new key gets an expiry computed by ``get_reset_seconds``.
        """
        def get_reset_seconds():
            now = datetime.now()
            today_begin = datetime(now.year, now.month, now.day, 0, 0, 0)
            tomorrow_begin = today_begin + timedelta(days=1)
            tomorrow_16 = tomorrow_begin + timedelta(hours=16)
            # NOTE(review): timedelta.seconds drops the day component, so
            # this is the time until the NEXT 16:00 (today's when called
            # before 16:00), not necessarily tomorrow's.  It can also be 0
            # just before 16:00.  Kept as is because the intended expiry is
            # not clear from this file; use total_seconds() if "tomorrow
            # 16:00" is really meant.
            rest_seconds = (tomorrow_16 - now).seconds
            return rest_seconds

        if self.r.exists(self.cfg['redis_temp_key']):  # already there, nothing to do
            return

        placeholder = {"symbol": "", "period": ""}
        self.r.lpush(self.cfg['redis_temp_key'], json.dumps(placeholder))
        self.r.expire(self.cfg['redis_temp_key'], get_reset_seconds())

    def update(self):
        """Worker loop: pump TqSdk and follow the display configuration.

        Roughly every half second the display configuration is re-read,
        ``SYMBOL`` and ``MULTIPLIER`` are updated, and the displayed
        contract's k-lines are subscribed if they are not in ``DATA`` yet.
        The loop ends once ``STOP_MARK`` is true; the TqSdk, Redis and MySQL
        connections are closed on the way out, including when the loop
        fails with an exception.
        """
        global SYMBOL, MULTIPLIER
        try:
            while not STOP_MARK:
                self.api.wait_update(time.time() + 10)

                now = time.time()
                if now - self.count <= 0.5:  # act only every half second
                    continue
                self.count = now

                self.get_display_symbol()
                # .get() keeps the thread alive while the display
                # configuration is still missing or incomplete.
                symbol = self.display_symbol.get('symbol', '')
                period = self.display_symbol.get('period', '')
                SYMBOL = f'{symbol}_{period}'
                MULTIPLIER = self.get_contract_multiplier(symbol) if symbol else None

                if symbol and SYMBOL not in DATA:
                    tq_symbol = self.symbol2tq(symbol)
                    if tq_symbol is None:
                        continue
                    try:
                        # Same default bar count as subscribe_kline().
                        DATA[SYMBOL] = self.api.get_kline_serial(tq_symbol,
                                                                 period,
                                                                 self.display_symbol.get('count', 200))
                    except Exception as e:
                        print(e)
        finally:
            self.api.close()
            self.r.close()
            self.conn.close()


if __name__ == '__main__':
    # Config locations in priority order; the first one that exists wins.
    config_candidates = ('../config/quant_master_config.json', './quant_master_config.json')
    config_path = next((p for p in config_candidates if os.path.exists(p)), None)

    if config_path is not None:
        with open(config_path) as f:
            CFG = json.load(f)
    else:
        # No config found: write a template for the user to fill in and stop.
        # Credentials are left blank on purpose — no account identifiers
        # belong in the source.
        CFG = {
            "tq_username": "",
            "tq_passwd": "",
            "mysql_host": "localhost",
            "mysql_port": 3306,
            "mysql_passwd": "",
            "mysql_user": "root",
            "mysql_database": "",
            "mysql_table": "turtle_pos",
            "mysql_general_ticker_info_table": "general_ticker_info",
            "redis_temp_key": "temp_symbol",
            # stable currently holds rb for the sma strategy; temp holds
            # contracts making new highs / new lows.
            "redis_stable_key": "stable_symbol",
            "redis_display_key": "display_symbol",
            "redis_display_stg": "display_stg",
            "redis_host": "localhost",
            "redis_port": 6379,
            "redis_passwd": "",
            "redis_db": 0
        }
        with open('./quant_master_config.json', 'w') as f:
            # Indented because the file is meant to be edited by hand.
            json.dump(CFG, f, indent=4)
        print('请填写配置文件后再启动！')
        raise SystemExit(0)

    QM = QuantMasterKernel(CFG)
    qm = threading.Thread(target=QM.update)
    qm.start()
    try:
        # Presumably gives the worker thread time to subscribe the first
        # k-lines before the window opens — TODO confirm the 9 s figure.
        time.sleep(9)
        kline = Draw(CFG)
        try:
            kline.run()
        finally:
            kline.r.close()
    finally:
        # Always stop the worker: it is a non-daemon thread, so leaving it
        # running after a failure here would keep the process alive forever.
        STOP_MARK = True
        qm.join()
